from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

from torch.nn.utils import weight_norm
import torchvision
from torchvision import transforms as T
import numpy as np
import satnet

class DigitConv(nn.Module):
    """Small CNN (two conv layers, two linear layers) giving 10 raw scores.

    The hard-coded flatten size ``4*4*50`` matches single-channel 28x28
    inputs: 28 -> 24 (conv 5x5) -> 12 (pool) -> 8 (conv 5x5) -> 4 (pool).

    Args:
        dropout: When truthy, dropout with ``p=0.2`` is applied before each
            fully connected layer, but only while the module is in training
            mode.
        activation_function: One of ``'ReLU'``, ``'PReLU'``, ``'ELU'`` or
            ``'Softplus'``. A single activation instance is shared by every
            layer except the last.

    Raises:
        ValueError: If ``activation_function`` is not a supported name.
    """

    # Supported activation names mapped to their module classes.
    _ACTIVATIONS = {
        'ReLU': nn.ReLU,
        'PReLU': nn.PReLU,
        'ELU': nn.ELU,
        'Softplus': nn.Softplus,
    }

    def __init__(self, dropout, activation_function):
        super(DigitConv, self).__init__()
        # Fail at construction time instead of with an AttributeError on the
        # first forward pass.
        if activation_function not in self._ACTIVATIONS:
            raise ValueError(
                'Unknown activation_function %r; expected one of %s'
                % (activation_function, sorted(self._ACTIVATIONS)))
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500*2)
        self.fc2 = nn.Linear(500*2, 10)
        # Registered after the linear layers so parameter order is unchanged.
        self.act = self._ACTIVATIONS[activation_function]()
        self.dropout = dropout

    def forward(self, x):
        """Return raw (un-normalised) class scores of shape ``(batch, 10)``."""
        x = self.act(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = self.act(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        if self.dropout:
            # F.dropout defaults to training=True; pass the module flag so
            # that eval() really disables dropout.
            x = F.dropout(x, p=0.2, training=self.training)
        x = self.act(self.fc1(x))
        if self.dropout:
            x = F.dropout(x, p=0.2, training=self.training)
        x = self.fc2(x)
        return x

# Code from https://github.com/kuangliu/pytorch-cifar/blob/master/models/lenet.py
class LeNet(nn.Module):
    """LeNet-style CNN mapping single-channel images to 10 raw class scores.

    The first linear layer expects ``16*4*4`` features, which is what a
    1x28x28 input yields after two 5x5 convolutions each followed by 2x2
    max pooling.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are kept stable: they are the state_dict keys.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return un-normalised scores of shape ``(batch, 10)``."""
        feats = x
        # Convolutional stage: conv -> ReLU -> 2x2 max pool, twice.
        for conv in (self.conv1, self.conv2):
            feats = F.max_pool2d(F.relu(conv(feats)), 2)
        # Flatten everything except the batch dimension.
        feats = feats.view(feats.size(0), -1)
        # Hidden fully connected layers use ReLU; the final layer is linear.
        for hidden in (self.fc1, self.fc2):
            feats = F.relu(hidden(feats))
        return self.fc3(feats)

class SATNet(nn.Module):
    """CNN backbone -> output activation -> 20-variable reasoning layer.

    The backbone produces 10 scores per image, ``output_layer`` squashes them
    into ``[0, 1]``, ten zeros are appended, and the resulting 20-wide vector
    is fed to ``satnet_layer``. The last ten entries of that layer's output
    are returned as the prediction.

    Args:
        m, aux: Forwarded unchanged to ``satnet.SATNet(20, m, aux)``; unused
            when ``nonsatnet`` is truthy.
        backbone: ``'LeNet'``, ``'Default'`` (``DigitConv``) or ``'ResNet18'``.
        dropout, activation_function: Forwarded to ``DigitConv``; only used
            by the ``'Default'`` backbone.
        output_layer: ``'softmax'`` or ``'sigmoid'``, applied to the backbone
            scores.
        nonsatnet: When truthy, replace the SATNet layer with a
            20 -> 1000 -> 20 MLP ending in a softmax (a non-SAT baseline).

    Raises:
        ValueError: If ``backbone`` or ``output_layer`` is not a supported
            name.
    """

    _BACKBONES = ('LeNet', 'Default', 'ResNet18')
    _OUTPUT_LAYERS = ('softmax', 'sigmoid')

    def __init__(self, m, aux, backbone, dropout, activation_function, output_layer, nonsatnet):
        super().__init__()
        # Validate first: an unknown name would otherwise leave the matching
        # sub-module undefined and only fail inside forward().
        if backbone not in self._BACKBONES:
            raise ValueError('Unknown backbone %r; expected one of %s'
                             % (backbone, list(self._BACKBONES)))
        if output_layer not in self._OUTPUT_LAYERS:
            raise ValueError('Unknown output_layer %r; expected one of %s'
                             % (output_layer, list(self._OUTPUT_LAYERS)))

        self.nonsatnet = nonsatnet
        if nonsatnet:
            self.satnet_layer = nn.Sequential(
                                    nn.Linear(20,1000),
                                    nn.ReLU(),
                                    nn.Linear(1000,20),
                                    nn.Softmax(dim=1))
        else:
            self.satnet_layer = satnet.SATNet(20, m, aux)

        self.backbone = backbone
        if backbone == 'LeNet':
            self.backbone_layer = LeNet()
        elif backbone == 'Default':
            self.backbone_layer = DigitConv(dropout, activation_function)
        else:  # 'ResNet18'
            # ResNet-18 emits 1000 scores; the extra linear layer maps them
            # down to the 10 digit classes.
            self.backbone_layer = nn.Sequential(torchvision.models.resnet18(pretrained=False, progress=True),
                                          nn.Linear(1000, 10))

        if output_layer == 'softmax':
            self.output_layer = nn.Softmax(dim=1)
        else:  # 'sigmoid'
            self.output_layer = nn.Sigmoid()

    def forward(self, x):
        """Return the last 10 outputs of the reasoning layer, shape ``(batch, 10)``."""
        if self.backbone == 'ResNet18':
            # Repeat the single input channel three times for ResNet.
            x = x.expand(-1,3,-1,-1)
        x = self.backbone_layer(x)
        x = self.output_layer(x)
        # Append ten zero-valued slots after the ten class probabilities.
        zeros = torch.zeros_like(x)
        x = torch.cat([x, zeros], dim=1)
        if self.nonsatnet:
            x = self.satnet_layer(x)
        else:
            # Mask is 1 for the first ten entries (backbone probabilities) and
            # 0 for the appended slots. Presumably satnet treats is_input=1
            # entries as given and infers the rest -- confirm against satnet.
            mask = torch.cat([torch.ones_like(zeros).int(), torch.zeros_like(zeros).int()], dim=1)
            x = self.satnet_layer(x.contiguous(), is_input=mask.contiguous())
        return x[:,10:]

def train(args, model, device, train_loader, optimizer_backbone, optimizer_satnet, epoch):
    """Train ``model`` for one pass over ``train_loader``.

    Each batch is scored with binary cross-entropy between the model output
    (values in [0, 1]) and the one-hot encoding of the integer labels; both
    optimizers are zeroed before and stepped after every backward pass.

    ``args`` and ``epoch`` are accepted for interface compatibility but are
    not used at present (they only fed per-batch progress logging, which is
    disabled).

    Returns:
        Training accuracy in percent, i.e. correct predictions made during
        the epoch divided by ``len(train_loader.dataset)``.
    """
    model.train()
    num_correct = 0
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer_backbone.zero_grad()
        optimizer_satnet.zero_grad()

        probs = model(inputs)
        # One-hot targets with the same shape/dtype/device as the output.
        onehot = torch.zeros_like(probs).scatter_(1, labels.unsqueeze(1), 1)
        F.binary_cross_entropy(probs, onehot).backward()

        optimizer_backbone.step()
        optimizer_satnet.step()

        # Predicted class = index of the largest output entry.
        predicted = probs.argmax(dim=1)
        num_correct += (predicted == labels).sum().item()

    return 100. * num_correct / len(train_loader.dataset)

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            target_onehot = torch.zeros_like(output)
            target_onehot.scatter_(1, target.unsqueeze(1), 1)
            loss = F.binary_cross_entropy(output, target_onehot, reduction='sum').item()
            test_loss += loss # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_acc = 100. * correct / len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        test_acc))
    return test_acc

# code from https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
def str2bool(v):
    """Parse a command-line boolean for use as an argparse ``type=``.

    Booleans are returned unchanged. Strings are compared case-insensitively:
    ``yes/true/t/y/1`` map to ``True`` and ``no/false/f/n/0`` to ``False``.

    Raises:
        argparse.ArgumentTypeError: For any other string.
    """
    if isinstance(v, bool):
        return v
    token = v.lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

def main():
    """Parse CLI options, train/evaluate the model on MNIST, and save results.

    Builds MNIST train/test loaders (normalised with mean 0.1307 and std
    0.3081), constructs ``SATNet`` from the command-line options, trains it
    for ``--epochs`` epochs with separate optimizers for the backbone and the
    reasoning layer, and writes per-epoch train/test accuracies plus the
    parameter count to a ``.dict`` file under ``logs/``.
    """
    import os  # local import: only needed here to create the output directory

    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N')
    parser.add_argument('--epochs', type=int, default=50, metavar='N')
    parser.add_argument('--no-cuda', action='store_true', default=False)
    parser.add_argument('--seed', type=int, default=1, metavar='S')
    parser.add_argument('--log-interval', type=int, default=100, metavar='N')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--m', type=int, default=100, metavar='M')
    parser.add_argument('--aux', type=int, default=100, metavar='A')
    parser.add_argument('--dropout', type=str2bool, default=False, metavar='D')
    parser.add_argument('--activation_function', type=str, default='ReLU', metavar='A',
                        choices=['ReLU', 'PReLU', 'ELU', 'Softplus'])
    parser.add_argument('--output_layer', type=str, default='softmax', metavar='O',
                        choices=['softmax', 'sigmoid'])
    parser.add_argument('--optimizer_backbone', type=str, default='Adam', metavar='O',
                        choices=['SGD', 'Adam'])
    parser.add_argument('--satnet_lr', type=float, default=0.1, metavar='LR')
    parser.add_argument('--backbone_lr', type=float, default=0.1, metavar='LR')
    parser.add_argument('--backbone', type=str, default='Default', metavar='B',
                        choices=['LeNet', 'Default', 'ResNet18'])
    parser.add_argument('--nonsatnet', type=str2bool, default=False, metavar='D')
    args = parser.parse_args()

    # Create the output directory before training so a missing "logs/" folder
    # cannot cause the results to be lost after all epochs have run.
    os.makedirs('logs', exist_ok=True)

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    # args.device is the CUDA device index; fall back to the CPU otherwise.
    device = torch.device(args.device if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.test_batch_size, shuffle=False, **kwargs)

    model = SATNet(args.m, args.aux, args.backbone, args.dropout, args.activation_function, args.output_layer, args.nonsatnet).to(device)
    num_params = sum([p.numel() for p in model.parameters()])
    print("Number of params: %d" % num_params)

    # The reasoning layer always uses Adam; the backbone optimizer is selectable.
    optimizer_satnet = optim.Adam(model.satnet_layer.parameters(), lr=args.satnet_lr)
    if args.optimizer_backbone == "SGD":
        optimizer_backbone = optim.SGD(model.backbone_layer.parameters(), lr=args.backbone_lr)
    elif args.optimizer_backbone == "Adam":
        optimizer_backbone = optim.Adam(model.backbone_layer.parameters(), lr=args.backbone_lr)

    train_accs = []
    test_accs = []
    for epoch in range(1, args.epochs + 1):
        train_acc = train(args, model, device, train_loader, optimizer_backbone, optimizer_satnet, epoch)
        train_accs.append(train_acc)
        test_acc = test(model, device, test_loader)
        test_accs.append(test_acc)

    results = {'train_accs': train_accs,
               'test_accs': test_accs,
               'num_params': num_params}
    # The file name encodes the run configuration; the non-SATNet baseline
    # has no m/aux, so it records the nonsatnet flag instead.
    if args.nonsatnet:
        out_path = ("logs/seed%d-non%s-dropout%s-act%s-out%s-opt%s-slr%f-blr%f-%s.dict" %
            (args.seed, args.nonsatnet, args.dropout, args.activation_function, args.output_layer,
            args.optimizer_backbone, args.satnet_lr, args.backbone_lr, args.backbone))
    else:
        out_path = ("logs/seed%d-m%d-aux%d-dropout%s-act%s-out%s-opt%s-slr%f-blr%f-%s.dict" %
            (args.seed, args.m, args.aux, args.dropout, args.activation_function, args.output_layer,
            args.optimizer_backbone, args.satnet_lr, args.backbone_lr, args.backbone))
    torch.save(results, out_path)


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()